{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1f5dfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel.\n",
    "%pip install pytorch_lightning --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c30510fe",
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1634657848021,
     "user": {
      "displayName": "Amit Joglekar",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
      "userId": "13549820997843161316"
     },
     "user_tz": 300
    },
    "id": "c30510fe"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from model import HybridModel\n",
    "from vocabulary import Vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e764a90a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 286
    },
    "executionInfo": {
     "elapsed": 3360,
     "status": "ok",
     "timestamp": 1634658090168,
     "user": {
      "displayName": "Amit Joglekar",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
      "userId": "13549820997843161316"
     },
     "user_tz": 300
    },
    "id": "e764a90a",
    "outputId": "f70a4ee0-3258-4e16-b820-6cefb09a717f"
   },
   "outputs": [],
   "source": [
    "def load_image(image_file_path, transform=None, size=(224, 224)):\n",
    "    \"\"\"Load an image as RGB, resize it, preview it, and optionally transform it.\n",
    "\n",
    "    Args:\n",
    "        image_file_path: Path (or file object) accepted by ``PIL.Image.open``.\n",
    "        transform: Optional callable applied to the resized PIL image. Its result\n",
    "            must support ``.unsqueeze(0)``, which is used to add a leading\n",
    "            batch dimension.\n",
    "        size: ``(width, height)`` to resize the image to. Defaults to ``(224, 224)``.\n",
    "\n",
    "    Returns:\n",
    "        The resized ``PIL.Image`` when ``transform`` is None; otherwise the\n",
    "        transformed image with a leading batch dimension added.\n",
    "\n",
    "    Side effects:\n",
    "        Draws the resized image on the current matplotlib axes via ``plt.imshow``.\n",
    "    \"\"\"\n",
    "    # Open inside a context manager so the file handle is released even if\n",
    "    # decoding fails; convert() returns a new image that outlives the file.\n",
    "    with Image.open(image_file_path) as source_img:\n",
    "        img = source_img.convert('RGB')\n",
    "    img = img.resize(tuple(size), Image.LANCZOS)\n",
    "    plt.imshow(np.asarray(img))\n",
    "    if transform is not None:\n",
    "        img = transform(img).unsqueeze(0)\n",
    "    return img\n",
    "\n",
    "image_file_path = 'sample.png'\n",
    "# Convert the PIL image to a tensor, then normalize each channel with (mean, std).\n",
    "# These values match the ImageNet statistics torchvision documents for its pretrained models.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.485, 0.456, 0.406),\n",
    "                         (0.229, 0.224, 0.225))])\n",
    "img = load_image(image_file_path, transform)\n",
    "\n",
    "hybrid_model = HybridModel.load_from_checkpoint(\"lightning_logs/version_0/checkpoints/epoch=4-step=784.ckpt\")\n",
    "# Put the model in evaluation mode before inference so that any dropout or\n",
    "# batch-norm layers it contains use their inference behaviour.\n",
    "hybrid_model.eval()\n",
    "token_ints = hybrid_model.get_caption(img)\n",
    "# Take the first (only) caption in the batch and move it to a NumPy array on the CPU.\n",
    "token_ints = token_ints[0].cpu().numpy()\n",
    "\n",
    "# NOTE(security): pickle.load can execute arbitrary code; only load a vocabulary file you trust.\n",
    "# NOTE(review): the `Vocabulary` import is presumably required so pickle can resolve the class - confirm.\n",
    "with open('coco_data/vocabulary.pkl', 'rb') as f:\n",
    "    vocabulary = pickle.load(f)\n",
    "\n",
    "# Map token ids back to tokens, stopping after the first '<end>' marker (which is kept).\n",
    "predicted_caption = []\n",
    "for token_int in token_ints:\n",
    "    token = vocabulary.int_to_token[token_int]\n",
    "    predicted_caption.append(token)\n",
    "    if token == '<end>':\n",
    "        break\n",
    "predicted_sentence = ' '.join(predicted_caption)\n",
    "\n",
    "print(predicted_sentence)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "cnn_rnn_predict_4_classes_scratch.ipynb",
   "provenance": [
    {
     "file_id": "1cWAuKVTRcV6POT0TzraL86l_qkaMw2-5",
     "timestamp": 1633816960272
    },
    {
     "file_id": "1fZx54u3ivULEmzXbH5y7BJv6wH2t-gD5",
     "timestamp": 1633694633210
    },
    {
     "file_id": "1p8xu-fWIhdE3wCHhGSEADuVP2YQ2I4LC",
     "timestamp": 1633406184483
    },
    {
     "file_id": "1qnmuz4y8AcgPEAc6vfwshdFyt-SnJqOb",
     "timestamp": 1633242673747
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
